{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TextGeneration-Transformers-PythonCodeTutorial.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "6bjli5Z7ZEVh"
      },
      "source": [
        "# %pip (not !pip) installs into the environment of the running kernel.\n",
        "# GPT-J support requires transformers >= 4.11, so state the lower bound explicitly.\n",
        "%pip install -q \"transformers>=4.11\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SPADZcRSY-3Y"
      },
      "source": [
        "from transformers import pipeline, set_seed"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k0zHPjIkqcEx"
      },
      "source": [
        "# download & load GPT-2, wrapped in a text-generation pipeline\n",
        "gpt2_generator = pipeline(\n",
        "    task='text-generation',\n",
        "    model='gpt2',\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "me1PAIvlqwKf"
      },
      "source": [
        "# seed the RNGs so re-running this cell reproduces the same samples\n",
        "set_seed(42)\n",
        "\n",
        "# generate 3 different sentences from one prompt;\n",
        "# results are sampled from the top 50 candidates (top_k=50)\n",
        "sentences = gpt2_generator(\n",
        "    \"To be honest, neural networks\",\n",
        "    do_sample=True,\n",
        "    top_k=50,\n",
        "    temperature=0.6,\n",
        "    max_length=128,\n",
        "    num_return_sequences=3,\n",
        ")\n",
        "\n",
        "# print each generated text followed by a separator line\n",
        "for sentence in sentences:\n",
        "    print(sentence[\"generated_text\"])\n",
        "    print(\"=\" * 50)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aXI92oauZCD4"
      },
      "source": [
        "# download & load the GPT-J 6B model -- a large download (22.5GB in size)\n",
        "gpt_j_generator = pipeline(\n",
        "    task='text-generation',\n",
        "    model='EleutherAI/gpt-j-6B',\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EaOAqXnXtOI0"
      },
      "source": [
        "# seed the RNGs so re-running this cell reproduces the same samples\n",
        "set_seed(42)\n",
        "\n",
        "# generate 3 sentences with TOP-K sampling (top_k=50)\n",
        "sentences = gpt_j_generator(\n",
        "    \"To be honest, robots will\",\n",
        "    do_sample=True,\n",
        "    top_k=50,\n",
        "    temperature=0.6,\n",
        "    max_length=128,\n",
        "    num_return_sequences=3,\n",
        ")\n",
        "\n",
        "# print each generated text followed by a separator line\n",
        "for sentence in sentences:\n",
        "    print(sentence[\"generated_text\"])\n",
        "    print(\"=\" * 50)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6N5qFdcFZG1v"
      },
      "source": [
        "# generate Python Code!\n",
        "# the prompt is the start of a Python file; the model continues it\n",
        "python_prompt = \"\"\"\n",
        "import os\n",
        "# make a list of all african countries\n",
        "\"\"\"\n",
        "completions = gpt_j_generator(\n",
        "    python_prompt,\n",
        "    do_sample=True,\n",
        "    top_k=10,\n",
        "    temperature=0.05,\n",
        "    max_length=256,\n",
        ")\n",
        "# only one sequence is requested, so show the first result\n",
        "print(completions[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-TOTvHiwwbK-"
      },
      "source": [
        "# Python prompt using OpenCV: the model continues after the last comment\n",
        "opencv_prompt = \"\"\"\n",
        "import cv2\n",
        "\n",
        "image = \"image.png\"\n",
        "\n",
        "# load the image and flip it\n",
        "\"\"\"\n",
        "completions = gpt_j_generator(\n",
        "    opencv_prompt,\n",
        "    do_sample=True,\n",
        "    top_k=10,\n",
        "    temperature=0.05,\n",
        "    max_length=256,\n",
        ")\n",
        "# only one sequence is requested, so show the first result\n",
        "print(completions[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_52OftmglAAv"
      },
      "source": [
        "# complete bash script!\n",
        "# The prompt shows one worked comment -> command pair, then a second comment\n",
        "# for the model to complete. The worked example must match its comment:\n",
        "# `ls /opt/*.py` lists the .py files in /opt, whereas `ls *.py /opt` would list\n",
        "# .py files of the current directory plus everything in /opt.\n",
        "bash_prompt = \"\"\"\n",
        "# get .py files in /opt directory\n",
        "ls /opt/*.py\n",
        "# get public ip address\n",
        "\"\"\"\n",
        "completions = gpt_j_generator(\n",
        "    bash_prompt,\n",
        "    max_length=256,\n",
        "    top_k=50,\n",
        "    temperature=0.05,\n",
        "    do_sample=True,\n",
        ")\n",
        "# only one sequence is requested, so show the first result\n",
        "print(completions[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2x527AykVquF"
      },
      "source": [
        "# generating bash script!\n",
        "# the prompt ends on a comment describing the commands the model should write\n",
        "nginx_prompt = \"\"\"\n",
        "# update the repository\n",
        "sudo apt-get update\n",
        "# install and start nginx\n",
        "\"\"\"\n",
        "completions = gpt_j_generator(\n",
        "    nginx_prompt,\n",
        "    max_length=128,\n",
        "    top_k=50,\n",
        "    temperature=0.1,\n",
        "    do_sample=True,\n",
        ")\n",
        "# only one sequence is requested, so show the first result\n",
        "print(completions[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "elK4JyyxwCPM"
      },
      "source": [
        "# Java code!\n",
        "# the prompt opens a class and main method, then states the task in a comment\n",
        "java_prompt = \"\"\"\n",
        "public class Test {\n",
        "\n",
        "public static void main(String[] args){\n",
        "  // printing the first 20 fibonacci numbers\n",
        "\"\"\"\n",
        "completions = gpt_j_generator(\n",
        "    java_prompt,\n",
        "    max_length=128,\n",
        "    top_k=50,\n",
        "    temperature=0.1,\n",
        "    do_sample=True,\n",
        ")\n",
        "# only one sequence is requested, so show the first result\n",
        "print(completions[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0US1Tv5xh-F2"
      },
      "source": [
        "# LaTeX!\n",
        "# raw string so the backslash in \\begin is kept literally, not read as an escape\n",
        "latex_prompt = r\"\"\"\n",
        "% list of Asian countries\n",
        "\\begin{enumerate}\n",
        "\"\"\"\n",
        "completions = gpt_j_generator(\n",
        "    latex_prompt,\n",
        "    max_length=128,\n",
        "    top_k=15,\n",
        "    temperature=0.1,\n",
        "    do_sample=True,\n",
        ")\n",
        "# only one sequence is requested, so show the first result\n",
        "print(completions[0][\"generated_text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "clkMMnsgh_YF"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}